[WIP] Feat/gpt oss example #63

Draft
ymwangg wants to merge 7 commits into main from feat/gpt-oss-example

ymwangg commented Jun 26, 2026

Contributor

Add p-eagle gpt-oss-20b example.

ymwangg and others added 7 commits June 25, 2026 13:38
Add a NKIPy example for OpenAI's gpt-oss MoE models (gpt-oss-20b / 120b),
mirroring the qwen3 example structure. The implementation is fully
config-driven, so both sizes share one codebase.

gpt-oss-specific handling:
- MXFP4 experts dequantized to bf16 at prep time
- interleaved gate/up de-interleaved at prep time
- clamped SwiGLU with gate_up/down biases
- per-head attention sinks + QKV/O biases (no QK-norm)
- alternating sliding-window / full attention (one kernel per type)
- YaRN RoPE (inv_freq precomputed from HF config)
- router with top-k-then-softmax and router bias
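
The "top-k-then-softmax" router ordering in the list above can be sketched as follows. This is a minimal NumPy illustration, not the NKIPy kernel: the function name, argument names, and tensor layouts are assumptions made for the example.

```python
import numpy as np

def route_topk_then_softmax(hidden, router_weight, router_bias, top_k):
    """Hypothetical helper. hidden: (tokens, d_model); router_weight: (d_model, n_experts)."""
    logits = hidden @ router_weight + router_bias            # (tokens, n_experts)
    # Step 1: pick the top-k experts on the raw (biased) logits.
    top_idx = np.argsort(-logits, axis=-1)[:, :top_k]        # (tokens, top_k)
    top_logits = np.take_along_axis(logits, top_idx, axis=-1)
    # Step 2: softmax over the selected experts only, so weights sum to 1 per token.
    shifted = top_logits - top_logits.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    weights /= weights.sum(axis=-1, keepdims=True)
    return top_idx, weights

# Toy input: an identity router makes the logits equal to the hidden vector.
idx, w = route_topk_then_softmax(
    np.array([[1.0, 3.0, 2.0, 0.0]]), np.eye(4), np.zeros(4), top_k=2
)
```

The ordering matters: a softmax over all experts followed by top-k would give weights that do not sum to 1 unless renormalized.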

Validated against HF on trn2 (TP=4): every generated token matches HF's
argmax or a bf16-resolution tie.

Implements parallel-drafting P-EAGLE (arXiv 2602.01469) on top of the
gpt-oss base model for speculative decoding on Trainium.

Components added (examples/models/gpt_oss/eagle/):
- config.py: EagleConfig for the 4-layer P-EAGLE drafter (llama3 RoPE,
  fc fusion, mask_hidden/ptd_token_id, d2t vocab map)
- tensor_preparation.py: convert P-EAGLE checkpoint to x@W form (replicated)
- kernels/drafter.py: parallel-drafting forward - K tokens in one pass via
  NTP (real hidden) + MTP (mask_hidden) positions with cross-depth mask
- kernels/drafter_layer.py: EAGLE-3 fusion midlayer + plain Llama layers
- kernels/verify.py: multi-position greedy argmax for verification
- drafter_model.py: device-side drafter model + compile
- speculate.py: full speculation loop (prefill → draft → verify → accept)

Base model changes:
- config.py: added aux_layers config + default_aux_layers() for EAGLE-3 taps
- gpt_oss.py: run_prefill() now optionally captures pre-layer hidden states
  at the 3 EAGLE-3 tap layers (2, L/2, L-3)
- kernels/attention.py: generalized decode path to support seq_len>1 (for
  the multi-token verify pass) via query_pos = start_pos + arange(seq_len)

Status: functionally correct (lossless greedy output verified against HF).
Acceptance length is below the paper's reported ~3.3 — under investigation
(likely a hidden-state position/timing issue in the draft-verify loop seeding).
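
For the seq_len>1 decode path noted under the base model changes, a minimal sketch of a mask built from query_pos = start_pos + arange(seq_len) is shown here. The function name, the boolean-mask convention, and the sliding-window rule (each query sees its last `window` positions including itself) are assumptions for illustration, not the contents of kernels/attention.py.

```python
import numpy as np

def decode_attention_mask(start_pos, seq_len, cache_len, window=None):
    """Hypothetical helper: True where a query may attend to a cached key."""
    query_pos = start_pos + np.arange(seq_len)     # absolute position of each query
    key_pos = np.arange(cache_len)                 # absolute position of each key
    allowed = key_pos[None, :] <= query_pos[:, None]          # causal constraint
    if window is not None:
        # Assumed sliding-window rule: keep only the last `window` positions.
        allowed &= key_pos[None, :] > query_pos[:, None] - window
    return allowed

full_mask = decode_attention_mask(start_pos=3, seq_len=2, cache_len=6)
sliding_mask = decode_attention_mask(start_pos=3, seq_len=2, cache_len=6, window=2)
```

With seq_len=1 this reduces to the usual single-token decode mask, which is why one code path can serve both decode and the multi-token verify pass.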
…yers

Switch aux capture to post-layer (output of tap layers 2/12/21) based on
HF validation showing the drafter predicts correctly with HF's hidden
states at hs[3]/hs[13]/hs[22] (output of layers 2/12/21).

Note: acceptance length remains low (~1.0) due to numerical divergence
between nkipy's Neuron-compiled target and the HF CPU reference the
drafter was trained against. The drafter kernel is mathematically correct
(validated against independent torch reference) and correctly predicts
the target when fed exact HF hidden states. The gap is an
implementation-coupling issue inherent to EAGLE-style speculation.

Key findings from the P-EAGLE paper (Figure 2, Figure 3, Section 3):

1. The drafter maintains its own KV cache across the full context
   (prompt + all accepted tokens). At each draft step, K positions
   attend to the FULL accumulated cache.

2. The attention mask is GROUP-CAUSAL: all K positions see the full
   cache (group 0), but within the K positions the NTP (group 1)
   and MTP (group 2+) positions use cross-depth causality — MTP
   positions cannot attend to positions at the same or later depth.

3. The NTP pair is (emb(t_n), hidden_after_processing_t_{n-1}),
   predicting t_{n+1}. The hidden is one step behind the embedding.
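
The one-step lag in finding 3 can be sketched as an index shift. The helper below is hypothetical and assumes hidden_states[i] is the target's hidden state after processing token_ids[i].

```python
import numpy as np

def build_ntp_pairs(token_ids, hidden_states):
    """Hypothetical helper pairing emb input t_n with hidden h_{n-1}.

    token_ids: (N,) token ids; hidden_states: (N, d), row i taken after token i.
    Returns the tokens to embed and the hidden rows to pair with them.
    """
    input_tokens = token_ids[1:]        # t_1 .. t_{N-1}
    input_hidden = hidden_states[:-1]   # h_0 .. h_{N-2}, one step behind
    return input_tokens, input_hidden

tokens = np.array([10, 11, 12, 13])
hidden = np.arange(4, dtype=np.float32)[:, None]   # row i stands in for h_i
pair_tokens, pair_hidden = build_ntp_pairs(tokens, hidden)
```

Pair n then predicts t_{n+1}, so the last pair here (t_3, h_2) predicts the first token that is not yet in the sequence.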

This commit adds:
- drafter_cpu.py: CPU reference drafter with full KV cache and
  standard causal attention (working infrastructure, mask needs
  the group-causal refinement for MTP positions)
- Fixes hidden state capture to post-layer (output of tap layers)
- Adds peagle_aux_layers config method

Status: KV cache infrastructure correct, still needs the group-causal
mask refinement for the MTP positions within the K-wide draft window.

Root-caused the low acceptance length (~1.4 vs the card's 3.30-3.80 at K=7)
on GPU by running the identical checkpoint through vLLM's eagle3
parallel-drafting path, capturing its drafter I/O, and reproducing it with a
standalone PyTorch reference (cosine 0.9999, 100% draft-token match). Three
bugs plus a prompt-formatting issue:

1. Context-blind drafting (dominant): speculate.py drove DrafterModel
   (kernels/drafter.py), which runs only the K draft positions under a (K,K)
   cross-depth mask with no prefill and no KV cache, so the MTP slots never
   saw the prompt. Rewired speculate.py to use the KV-cached DrafterCPU:
   prefill the drafter on the prompt (EAGLE +1 shift), then each step roll the
   cache back to the last accepted position and run [newly-accepted tokens |
   K-1 ptd slots] in one parallel forward attending to the full context.

2. rollback() truncated the wrong axis: the cache is (B, n_kv, seq, head_dim)
   and rollback sliced dim 1 (n_kv) instead of dim 2 (seq), so rejected
   speculative KV was never discarded and corrupted later steps.

3. Aux tap off-by-one: vLLM's eagle3 default (2, n//2, n-3) captures the
   residual stream entering those layers; our post-layer capture must shift
   down one, so default_aux_layers now returns (1, 11, 20). Verified on GPU
   the drafter's 3 fc chunks equal target layer outputs (1, 11, 20) at cos 1.0.

4. Prompt formatting: the drafter is trained on chat data; raw prompts roughly
   halve acceptance (GPU, K=7: 3.65 chat vs 1.99 raw). speculate.py now applies
   the chat template by default (--raw-prompt to opt out).
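
The axis mix-up in item 2 is easy to reproduce on a toy cache. The class below is a hypothetical stand-in for the drafter's cache, assuming only the (B, n_kv, seq, head_dim) layout stated above.

```python
import numpy as np

class ToyKVCache:
    """Hypothetical cache holding keys/values shaped (B, n_kv, seq, head_dim)."""

    def __init__(self, k, v):
        self.k, self.v = k, v

    def rollback(self, keep_len):
        # Truncate the sequence axis (dim 2) so rejected speculative entries are dropped.
        self.k = self.k[:, :, :keep_len, :]
        self.v = self.v[:, :, :keep_len, :]

k = np.zeros((1, 8, 10, 64), dtype=np.float32)
v = np.zeros_like(k)

buggy_shape = k[:, :6].shape   # slices dim 1 (n_kv): heads are dropped, seq stays 10
cache = ToyKVCache(k, v)
cache.rollback(6)              # slices dim 2 (seq): 4 speculative positions discarded
```

A shape assertion on the sequence axis after every rollback is a cheap guard against this class of bug.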

Still produces all K draft tokens in a single forward pass (parallel drafting).
Adds test_drafter_cpu.py guarding the rollback/full-context invariants (skips
without the checkpoint). Validated against vLLM on GPU; not yet re-validated on
Trainium, and the on-device kernels/drafter.py KV-cache port remains follow-up.
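
The tap shift from item 3 is plain index arithmetic: the residual stream entering layer i is the output of layer i-1. The function names below are hypothetical; only the (2, n//2, n-3) rule and the (1, 11, 20) result come from the notes above, and n=24 is inferred from the 2/12/21 taps mentioned earlier.

```python
def pre_layer_taps(n_layers):
    """Layers whose *input* residual stream is tapped (the rule quoted above)."""
    return (2, n_layers // 2, n_layers - 3)

def post_layer_taps(n_layers):
    """Equivalent taps when capturing layer *outputs*: shift each index down by one."""
    return tuple(i - 1 for i in pre_layer_taps(n_layers))

taps_in = pre_layer_taps(24)
taps_out = post_layer_taps(24)
```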

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

ymwangg commented Jun 30, 2026

Contributor Author

gpt-oss + P-EAGLE — Implementation Status Report

Status as of the current branch feat/gpt-oss-example. Target: gpt-oss-20b on AWS
Trainium (TP=4) via NKIPy.

TL;DR

  • Base gpt-oss-20b: complete and working on Trainium (~70 tok/s decode).
  • P-EAGLE speculative decoding: numerically correct and fully on-device. The
    drafter port (the recent work) is done and validated: the device drafter
    produces byte-identical end-to-end output to the GPU-validated CPU reference
    and is ~6× faster than it.
  • Open issue: P-EAGLE end-to-end is currently a net slowdown vs the base
    decode (~0.57×) despite accepting ~3.4 tokens/step, because the speculative
    loop is synchronous while the base decode is async/double-buffered. The drafter
    is no longer the bottleneck; the loop is.

Components

| Component | File(s) | Status |
| --- | --- | --- |
| Base MoE model (TP, MXFP4→bf16, sinks, sliding window, YaRN RoPE, double-buffered decode) | gpt_oss.py, kernels/, tensor_preparation.py | ✅ Working on hardware |
| Drafter, CPU reference (KV-cached full-context parallel drafting) | eagle/drafter_cpu.py | ✅ GPU-validated vs vLLM |
| Drafter, on-device (KV-cached forward + draft head) | eagle/drafter_model.py, eagle/kernels/drafter*.py | ✅ Ported + Trainium-validated |
| Speculation loop (prefill → draft → verify → accept) | eagle/speculate.py | ✅ Working; --device-drafter selects on-device path |
| Verify kernel (K+1 multi-token target pass + per-pos argmax) | eagle/kernels/verify.py | ✅ Working |
| Tests | eagle/test_drafter_cpu.py, eagle/test_drafter_device.py | ✅ Pass (real + synthetic) |

What changed recently (this work)

The drafter had a working CPU reference but the on-device path was the old
context-blind kernel (K positions under a static (K,K) cross-depth mask, no
prefill, no KV cache) and was unused. Three commits brought the device path up to
parity and tightened metrics:

| Commit | Summary |
| --- | --- |
| fdb13e0 | fix: correct P-EAGLE drafter to full-context KV-cached parallel drafting (CPU reference + loop bookkeeping; GPU-validated vs vLLM) |
| e8da8b1 | feat: port P-EAGLE drafter to on-device KV-cached parallel drafting (DrafterModel + kernels/drafter*.py; --device-drafter; test_drafter_device.py) |
| 2cd9e2f | fix: correct acceptance-length metric to count only emitted tokens (excludes the final step truncated by the token budget) |

Port design note: the device drafter keeps static kernel shapes — it commits
newly-accepted tokens one at a time (S=1) and runs the K-wide draft window
(S=K) — so it avoids per-step recompiles. Causal attention makes the
one-at-a-time commits numerically identical to a single batched forward.
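
The equivalence claimed in the port design note can be checked on a toy single-head attention. This sketch uses precomputed random Q/K/V and made-up sizes; it illustrates the causal-mask argument only and is not the device kernel.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def causal_attention(q, k_cache, v_cache, start_pos):
    """q: (S, d) for positions start_pos..start_pos+S-1; caches: (T, d)."""
    S, d = q.shape
    T = k_cache.shape[0]
    scores = q @ k_cache.T / np.sqrt(d)
    query_pos = start_pos + np.arange(S)
    mask = np.arange(T)[None, :] <= query_pos[:, None]   # causal: no future keys
    scores = np.where(mask, scores, -np.inf)
    return softmax(scores) @ v_cache

rng = np.random.default_rng(0)
d, prefix, new = 8, 5, 3          # toy sizes (assumed)
q = rng.standard_normal((prefix + new, d))
k = rng.standard_normal((prefix + new, d))
v = rng.standard_normal((prefix + new, d))

# One batched forward over the 3 new positions against the full cache...
batched = causal_attention(q[prefix:], k, v, start_pos=prefix)
# ...versus committing them one at a time (S=1), each seeing only its own prefix.
one_by_one = np.concatenate([
    causal_attention(q[p:p + 1], k[:p + 1], v[:p + 1], start_pos=p)
    for p in range(prefix, prefix + new)
])
```

Both paths mask out exactly the same future keys for each query, so the results agree up to floating-point rounding; the same argument extends layer by layer because each position's K/V depends only on earlier positions.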

Also fixed along the way: speculate.py's chat-template branch crashed on recent
transformers (it returns a BatchEncoding, not an array); now extracts
["input_ids"]. The template ships separately as chat_template.jinja.
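
A defensive way to handle both return shapes is sketched below. The helper name is hypothetical, and it assumes the tokenizer's return value either behaves like a mapping with an "input_ids" key or is already a plain sequence of ids; it does not reproduce the actual code in speculate.py.

```python
from collections.abc import Mapping

import numpy as np

def to_input_ids(encoded):
    """Hypothetical helper: accept a dict-like encoding or a bare id sequence."""
    if isinstance(encoded, Mapping):
        # Dict-like result: pull out the ids explicitly.
        encoded = encoded["input_ids"]
    return np.asarray(encoded, dtype=np.int64)

from_mapping = to_input_ids({"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]})
from_list = to_input_ids([4, 5])
```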

Validation

| Check | Result |
| --- | --- |
| CPU drafter math vs vLLM (GPU) | cos 0.9999 prefill; 100% draft-token match |
| Device drafter vs CPU drafter (synthetic, Trainium) | logit cos 0.9999 across prefill + KV-cached draft steps with rollback |
| Device vs CPU drafter, end-to-end (real checkpoints) | byte-identical output |
| test_drafter_cpu.py (real checkpoint) | 3/3 pass |
| test_drafter_device.py (Neuron-gated) | pass |

Under greedy verification, emitted tokens are the target's argmax regardless of
the drafter, so identical output is expected by construction — the drafter only
affects speed. Acceptance length is the meaningful correctness signal (a broken /
context-blind drafter collapses toward ~1); we observe ~2.4–4.6.
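
Why the output is drafter-independent can be seen in a minimal greedy accept rule. This is one common formulation written as an assumption; the function name and the exact bookkeeping in speculate.py may differ.

```python
def greedy_accept(draft_tokens, target_argmax):
    """Hypothetical accept step for one speculation round.

    draft_tokens: the K drafted tokens.
    target_argmax: K+1 tokens, where entry i is the target's greedy choice for
    the token following position i of [last_token, *draft_tokens].
    """
    assert len(target_argmax) == len(draft_tokens) + 1
    n_match = 0
    for drafted, wanted in zip(draft_tokens, target_argmax):
        if drafted != wanted:
            break
        n_match += 1
    # Accepted drafts plus one target token (the correction, or a bonus if all matched).
    return list(draft_tokens[:n_match]) + [target_argmax[n_match]]

partial = greedy_accept([5, 6, 7], [5, 6, 9, 1])   # third draft rejected
full = greedy_accept([5, 6, 7], [5, 6, 7, 8])      # every draft accepted
none = greedy_accept([4, 4, 4], [5, 6, 9, 1])      # first draft rejected
```

Every emitted list is a prefix of target_argmax, so the drafter can only change how many tokens are emitted per step (between 1 and K+1), never which tokens.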

Performance (5-prompt sweep, n=128, K=7, TP=4, chat template, identical env)

| Config | Mean decode tok/s | vs base | Mean acceptance |
| --- | --- | --- | --- |
| Base (no speculation) | 69.8 | 1.00× | |
| P-EAGLE, device drafter | 39.8 | 0.57× | 3.44 |
| P-EAGLE, CPU drafter | 6.6 | 0.10× | 3.37 |

  • Device drafter is ~6× faster than the CPU drafter (the port's contribution).
  • Speedup tracks acceptance: 0.72× on high-acceptance prompts (haiku 4.59,
    binary-search 4.38) down to 0.44× on low-acceptance ones (transformer 2.44).
  • None exceed 1× — see the bottleneck below.
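
A back-of-the-envelope reading of the table, assuming throughput ≈ acceptance length / time per speculative step and using the reported means (so the figures are approximate, not measured step times):

```python
# Reported means from the sweep above.
base_tok_s = 69.8      # base decode throughput
spec_tok_s = 39.8      # P-EAGLE with device drafter
mean_accept = 3.44     # tokens emitted per speculative step

# Implied time per step under the assumed model tok/s = acceptance / step_time.
base_step_ms = 1000.0 / base_tok_s                  # one token per base step
spec_step_ms = 1000.0 * mean_accept / spec_tok_s    # one draft + verify + accept round
break_even_ms = 1000.0 * mean_accept / base_tok_s   # step time at which speedup = 1×

print(f"base step      ~{base_step_ms:.1f} ms")
print(f"spec step      ~{spec_step_ms:.1f} ms")
print(f"break-even at  ~{break_even_ms:.1f} ms")
```

Under this model a speculative step takes roughly 86 ms against a break-even of about 49 ms, so the loop would need to shed a bit over 40% of its per-step time at the current acceptance length to match base decode.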

(Full per-prompt table in gpt-oss-peagle-results.md.)

Known issues / open work

  1. Speculation is a net slowdown vs base decode (primary issue). The base path
    is async / double-buffered (overlapped execution); the speculative loop is
    synchronous, paying every step for a K+1-wide verify pass plus host-side
    orchestration (accept/reject compare, embedding lookups, drafter call) that is
    not pipelined. To make P-EAGLE a net win:

    • async / double-buffered speculation (overlap verify + draft + sampling);
    • move per-step host orchestration on-device to cut host round-trips;
    • tune K per workload (lower K helps low-acceptance prompts).
  2. Acceptance below model card on some prompts. Chat formatting lifts
    acceptance (raw prompts ≈2.95 → chat 4.1–4.6 on favorable prompts), but
    reasoning-heavy prompts still sit ~2.4–2.5. Worth checking tap-layer choice and
    prompt distribution against the card's eval setup.

  3. Single-run benchmarks. Numbers are one run per (prompt, config) — good for
    relative comparison; average several runs before quoting as benchmarks.
